Optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs & add mutlti-node scripts#207
Merged
Conversation
added 3 commits
May 29, 2026 15:47
Load full pretrained weights on each node's local rank0 and distribute shards only within the node, reducing global rank0 pressure for large EP/FSDP jobs.
Contributor
There was a problem hiding this comment.
Code Review
This pull request removes the FSDP2 state-dict loading patch from the Accelerate strategy and implements node-local state-dict loading and scattering in the Native FSDP strategy using local rank topology. The review feedback suggests optimizing the node-local communication by creating a node-local process group once and using dist.broadcast instead of inefficient point-to-point dist.send and dist.recv calls for every parameter. Additionally, it is recommended to verify that distributed training is initialized before retrieving the local rank to prevent runtime errors in non-distributed environments.
Implement `get_adapter_state_dict` methods in AccelerateStrategy and NativeFSDPStrategy to efficiently collect only LoRA adapter parameters, avoiding full model state dict collection for large FSDP/EP jobs. The NativeFSDP version includes EP-aware all-gather for expert parameters.
added 3 commits
June 1, 2026 16:03
The previous implementation used individual send/recv operations for each target rank, which was inefficient and could cause performance bottlenecks. This change replaces them with a single broadcast call using a new local group, improving communication efficiency and reducing code complexity.
Avoid HCCL subgroup broadcast for node-local weight loading, since dynamic subgroup communicators can fail on NPU. Fall back to send/recv on HCCL while keeping local broadcast for other backends.
tpx818
reviewed
Jun 1, 2026
tpx818
approved these changes
Jun 2, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR type
PR information
This PR improves EP + FSDP2 + LoRA training for DeepSeek V4 Flash in multi-node environments.
Main changes:
Optimize Native FSDP
memory_efficient_initweight loading for multi-node EP/FSDP jobs.Previously, global rank 0 distributed pretrained weights and EP expert shards to all ranks. With large
world_size, especially whenep_fsdp_size=1, this put heavy communication pressure on global rank 0.This PR changes the loading path so each node's local rank 0 loads/captures the full pretrained state and distributes tensors only to ranks on the same node. EP shard selection still uses
rank_to_ep_rank, so each target rank receives the correct EP slice before FSDP/DTensor placement.Optimize LoRA checkpoint saving.
LoRA adapter saving no longer calls full-model
get_full_state_dict()and then filters LoRA keys. Instead, it collects only LoRA adapter parameters. Native FSDP keeps EP-aware LoRA all-gather, avoiding large base-model state_dict materialization during adapter checkpoint saving.Remove the Twinkle-side Accelerate FSDP2 state-dict loading monkey patch.
AccelerateStrategynow relies on native Accelerate behavior formemory_efficient_init/cpu_ram_efficient_loading.Add multi-node DeepSeek V4 Flash EP + FSDP2 + LoRA cookbook script.
The DeepSeek V4 LoRA cookbook now supports configuring GPU/NPU count via
NUM_GPUS.Experiment results
[2026-06-01 07:35:49][INFO:twinkle] Current is step 4 of 16, metric: {'loss': '3.0792', 'learning rate(param group 1)': '0.000000e+00', 'learning rate(param group 2)': '0.000000e+00', 'iters': 0, 'total time elapse': '2.9 minutes', 'speed': '0.00 iters/s'}
[2026-06-01 07:36:16][INFO:twinkle] Current is step 8 of 16, metric: {'loss': '2.9792', 'grad_norm': '135.397873', 'learning rate(param group 1)': '2.000000e-05', 'learning rate(param group 2)': '2.000000e-05', 'iters': 1, 'total time elapse': '200 seconds', 'speed': '0.04 iters/s'}
[2026-06-01 07:36:51][INFO:twinkle] Current is step 12 of 16, metric: {'loss': '3.0201', 'grad_norm': '136.174759', 'learning rate(param group 1)': '4.000000e-05', 'learning rate(param group 2)': '4.000000e-05', 'iters': 2, 'total time elapse': '236 seconds', 'speed': '0.03 iters/s'}
[2026-06-01 07:37:22][INFO:twinkle] Saved final adapter to /nas/diskz/checkpoint-final